{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Create TFJob using Kubeflow Training SDK"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "This is a sample for Kubeflow Training SDK `kubeflow-training`.\n",
    "\n",
    "The notebook shows how to use the Kubeflow Training SDK to create, get, wait for, check and delete a TFJob."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Kubeflow Training Python SDKs\n",
    "\n",
    "You need to install Kubeflow Training SDK to run this Notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO (andreyvelich): Change to release version when SDK with the new APIs is published.\n",
    "# Use the %pip magic rather than !pip so the package is installed into the\n",
    "# environment of the running kernel, not whichever pip is first on the shell PATH.\n",
    "%pip install git+https://github.com/kubeflow/training-operator.git#subdirectory=sdk/python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from kubernetes.client import V1PodTemplateSpec\n",
    "from kubernetes.client import V1ObjectMeta\n",
    "from kubernetes.client import V1PodSpec\n",
    "from kubernetes.client import V1Container\n",
    "\n",
    "\n",
    "from kubeflow.training import KubeflowOrgV1ReplicaSpec\n",
    "from kubeflow.training import KubeflowOrgV1TFJob\n",
    "from kubeflow.training import KubeflowOrgV1TFJobSpec\n",
    "from kubeflow.training import KubeflowOrgV1RunPolicy\n",
    "from kubeflow.training import TrainingClient\n",
    "\n",
    "from kubeflow.training import constants"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Define TFJob"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "The demo runs the TensorFlow MNIST example as a TFJob with 2 workers, 1 chief, and 1 parameter server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Identifiers for the TFJob. `name` is also what the later cells pass to the\n",
    "# TrainingClient calls (get, wait, logs, delete), so the job must be created with it.\n",
    "name = \"mnist\"\n",
    "namespace = \"kubeflow-user-example-com\"\n",
    "container_name = \"tensorflow\"\n",
    "\n",
    "# Single container definition shared by the Worker, Chief and PS replicas.\n",
    "container = V1Container(\n",
    "    name=container_name,\n",
    "    image=\"gcr.io/kubeflow-ci/tf-mnist-with-summaries:1.0\",\n",
    "    command=[\n",
    "        \"python\",\n",
    "        \"/var/tf_mnist/mnist_with_summaries.py\",\n",
    "        \"--log_dir=/train/logs\",\n",
    "        \"--learning_rate=0.01\",\n",
    "        \"--batch_size=150\",\n",
    "    ],\n",
    ")\n",
    "\n",
    "# The three replica specs differ only in their replica count.\n",
    "worker = KubeflowOrgV1ReplicaSpec(\n",
    "    replicas=2,\n",
    "    restart_policy=\"Never\",\n",
    "    template=V1PodTemplateSpec(\n",
    "        spec=V1PodSpec(\n",
    "            containers=[container]\n",
    "        )\n",
    "    )\n",
    ")\n",
    "\n",
    "chief = KubeflowOrgV1ReplicaSpec(\n",
    "    replicas=1,\n",
    "    restart_policy=\"Never\",\n",
    "    template=V1PodTemplateSpec(\n",
    "        spec=V1PodSpec(\n",
    "            containers=[container]\n",
    "        )\n",
    "    )\n",
    ")\n",
    "\n",
    "ps = KubeflowOrgV1ReplicaSpec(\n",
    "    replicas=1,\n",
    "    restart_policy=\"Never\",\n",
    "    template=V1PodTemplateSpec(\n",
    "        spec=V1PodSpec(\n",
    "            containers=[container]\n",
    "        )\n",
    "    )\n",
    ")\n",
    "\n",
    "tfjob = KubeflowOrgV1TFJob(\n",
    "    api_version=constants.API_VERSION,\n",
    "    kind=constants.TFJOB_KIND,\n",
    "    # Use the `name` variable (not a repeated literal) so the created job always\n",
    "    # matches the name the later cells look up.\n",
    "    metadata=V1ObjectMeta(name=name, namespace=namespace),\n",
    "    spec=KubeflowOrgV1TFJobSpec(\n",
    "        # Note: clean_pod_policy is the string \"None\", not Python's None.\n",
    "        # NOTE(review): presumably this keeps the pods after the job ends so their\n",
    "        # logs can still be fetched - confirm against the Training Operator docs.\n",
    "        run_policy=KubeflowOrgV1RunPolicy(clean_pod_policy=\"None\"),\n",
    "        tf_replica_specs={\n",
    "            \"Worker\": worker,\n",
    "            \"Chief\": chief,\n",
    "            \"PS\": ps,\n",
    "        },\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create TFJob\n",
    "\n",
    "You have to create a Training Client to deploy your TFJob in your cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "TFJob kubeflow-user-example-com/mnist has been created\n"
     ]
    }
   ],
   "source": [
    "# Namespace and Job kind will be reused in every API call.\n",
    "training_client = TrainingClient(namespace=namespace, job_kind=constants.TFJOB_KIND)\n",
    "training_client.create_job(tfjob)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the Created TFJob\n",
    "\n",
    "You can verify the created TFJob status."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'completion_time': None,\n",
       " 'conditions': [{'last_transition_time': datetime.datetime(2023, 9, 8, 21, 42, 34, tzinfo=tzutc()),\n",
       "                 'last_update_time': datetime.datetime(2023, 9, 8, 21, 42, 34, tzinfo=tzutc()),\n",
       "                 'message': 'TFJob mnist is created.',\n",
       "                 'reason': 'TFJobCreated',\n",
       "                 'status': 'True',\n",
       "                 'type': 'Created'},\n",
       "                {'last_transition_time': datetime.datetime(2023, 9, 8, 21, 42, 35, tzinfo=tzutc()),\n",
       "                 'last_update_time': datetime.datetime(2023, 9, 8, 21, 42, 35, tzinfo=tzutc()),\n",
       "                 'message': 'TFJob kubeflow-user-example-com/mnist is running.',\n",
       "                 'reason': 'TFJobRunning',\n",
       "                 'status': 'True',\n",
       "                 'type': 'Running'}],\n",
       " 'last_reconcile_time': None,\n",
       " 'replica_statuses': {'Chief': {'active': 1,\n",
       "                                'failed': None,\n",
       "                                'label_selector': None,\n",
       "                                'selector': None,\n",
       "                                'succeeded': None},\n",
       "                      'PS': {'active': 1,\n",
       "                             'failed': None,\n",
       "                             'label_selector': None,\n",
       "                             'selector': None,\n",
       "                             'succeeded': None},\n",
       "                      'Worker': {'active': 2,\n",
       "                                 'failed': None,\n",
       "                                 'label_selector': None,\n",
       "                                 'selector': None,\n",
       "                                 'succeeded': None}},\n",
       " 'start_time': datetime.datetime(2023, 9, 8, 21, 42, 34, tzinfo=tzutc())}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_client.get_job(name).status"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the TFJob Conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'last_transition_time': datetime.datetime(2023, 9, 8, 21, 42, 34, tzinfo=tzutc()),\n",
       "  'last_update_time': datetime.datetime(2023, 9, 8, 21, 42, 34, tzinfo=tzutc()),\n",
       "  'message': 'TFJob mnist is created.',\n",
       "  'reason': 'TFJobCreated',\n",
       "  'status': 'True',\n",
       "  'type': 'Created'},\n",
       " {'last_transition_time': datetime.datetime(2023, 9, 8, 21, 42, 35, tzinfo=tzutc()),\n",
       "  'last_update_time': datetime.datetime(2023, 9, 8, 21, 42, 35, tzinfo=tzutc()),\n",
       "  'message': 'TFJob kubeflow-user-example-com/mnist is running.',\n",
       "  'reason': 'TFJobRunning',\n",
       "  'status': 'True',\n",
       "  'type': 'Running'}]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_client.get_job_conditions(name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Wait Until TFJob Finishes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "training_client.wait_for_job_conditions(name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Verify if TFJob Has Succeeded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "training_client.is_job_succeeded(name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the TFJob Training Logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "training_client.get_job_logs(name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Delete the TFJob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "training_client.delete_job(name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
